/*
 * Copyright 2015, The Querydsl Team (http://www.querydsl.com/team)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.querydsl.sql.dml;

import java.sql.*;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import javax.annotation.Nullable;
import javax.inject.Provider;

import org.slf4j.Logger;
import org.slf4j.MDC;

import com.querydsl.core.QueryMetadata;
import com.querydsl.core.dml.DMLClause;
import com.querydsl.core.support.QueryBase;
import com.querydsl.core.types.ParamExpression;
import com.querydsl.core.types.ParamNotSetException;
import com.querydsl.core.types.Path;
import com.querydsl.sql.*;

import static com.google.common.collect.Lists.newArrayList;

/**
 * {@code AbstractSQLClause} is a superclass for SQL based DMLClause implementations.
 *
 * <p>It holds the {@link Configuration}, the registered {@link SQLListener}s and the
 * connection source shared by concrete clauses. It also provides helpers for serialization,
 * parameter binding, batch execution, statement cleanup and query logging.</p>
 *
 * @param <C> concrete subtype
 *
 * @author tiwe
 */
public abstract class AbstractSQLClause<C extends AbstractSQLClause<C>> implements DMLClause<C> {

    /** Configuration used for serialization, value binding and exception translation. */
    protected final Configuration configuration;

    /** Listeners notified about context start, exceptions and context end. */
    protected final SQLListeners listeners;

    /** Whether serializers created via {@link #createSerializer()} inline constants as literals. */
    protected boolean useLiterals;

    /**
     * Listener context of the current execution; reset to {@code null} by
     * {@link #endContext(SQLListenerContextImpl)}. Presumably assigned by subclasses.
     */
    protected SQLListenerContextImpl context;

    /** Connection source consulted by {@link #connection()} when no connection was given directly. */
    @Nullable
    private Provider<Connection> connProvider;

    /** Connection in use: either supplied directly or obtained once from {@link #connProvider}. */
    @Nullable
    private Connection conn;

    /**
     * Create a clause without a connection source.
     *
     * <p>{@link #connection()} will fail unless a subclass supplies one.</p>
     *
     * @param configuration configuration to use
     */
    public AbstractSQLClause(Configuration configuration) {
        this.configuration = configuration;
        this.listeners = new SQLListeners(configuration.getListeners());
        this.useLiterals = configuration.getUseLiterals();
    }

    /**
     * Create a clause that obtains its connection lazily from the given provider.
     *
     * @param configuration configuration to use
     * @param connProvider  provider consulted on the first call to {@link #connection()}
     */
    public AbstractSQLClause(Configuration configuration, Provider<Connection> connProvider) {
        this(configuration);
        this.connProvider = connProvider;
    }

    /**
     * Create a clause bound to the given connection.
     *
     * @param configuration configuration to use
     * @param conn          connection to execute against
     */
    public AbstractSQLClause(Configuration configuration, Connection conn) {
        this(configuration);
        this.conn = conn;
    }

    /**
     * Add a listener
     *
     * @param listener listener to add
     */
    public void addListener(SQLListener listener) {
        listeners.add(listener);
    }

    /**
     * Clear the internal state of the clause
     */
    public abstract void clear();

    /**
     * Called to create and start a new SQL Listener context
     *
     * @param connection the database connection
     * @param metadata   the meta data for that context
     * @param entity     the entity for that context
     * @return the newly started context
     */
    protected SQLListenerContextImpl startContext(Connection connection, QueryMetadata metadata, RelationalPath<?> entity) {
        // named differently from the field to avoid shadowing this.context
        SQLListenerContextImpl newContext = new SQLListenerContextImpl(metadata, connection, entity);
        listeners.start(newContext);
        return newContext;
    }

    /**
     * Called to make the call back to listeners when an exception happens
     *
     * @param context the current context in play
     * @param e       the exception
     */
    protected void onException(SQLListenerContextImpl context, Exception e) {
        context.setException(e);
        listeners.exception(context);
    }

    /**
     * Called to end a SQL listener context.
     *
     * <p>Also clears the {@link #context} field.</p>
     *
     * @param context the listener context to end
     */
    protected void endContext(SQLListenerContextImpl context) {
        listeners.end(context);
        this.context = null;
    }

    /**
     * Build the SQL string and bind values for a clause that has been serialized.
     *
     * <p>{@link ParamExpression} constants are replaced with their values from
     * {@code metadata}. All other constants are used as they are.</p>
     *
     * @param metadata   metadata holding parameter values
     * @param serializer serializer whose output and constants are used
     * @return the SQL string together with its ordered bind values
     * @throws ParamNotSetException if a parameter used in the clause has no value
     */
    protected SQLBindings createBindings(QueryMetadata metadata, SQLSerializer serializer) {
        String queryString = serializer.toString();
        List<Object> args = newArrayList();
        Map<ParamExpression<?>, Object> params = metadata.getParams();
        for (Object o : serializer.getConstants()) {
            if (o instanceof ParamExpression) {
                if (!params.containsKey(o)) {
                    throw new ParamNotSetException((ParamExpression<?>) o);
                }
                // reuse the map fetched above instead of querying the metadata again
                o = params.get(o);
            }
            args.add(o);
        }
        return new SQLBindings(queryString, args);
    }

    /**
     * Create a serializer for this clause, honoring the clause-level {@link #useLiterals} flag.
     *
     * @return new serializer
     */
    protected SQLSerializer createSerializer() {
        // NOTE(review): the boolean argument presumably selects DML serialization mode — confirm
        SQLSerializer serializer = new SQLSerializer(configuration, true);
        serializer.setUseLiterals(useLiterals);
        return serializer;
    }

    /**
     * Get the SQL string and bindings
     *
     * @return SQL and bindings
     */
    public abstract List<SQLBindings> getSQL();

    /**
     * Set the parameters to the given PreparedStatement
     *
     * @param stmt preparedStatement to be populated
     * @param objects list of constants
     * @param constantPaths list of paths related to the constants
     * @param params map of param to value for param resolving
     * @throws IllegalArgumentException if {@code objects} and {@code constantPaths} differ in size
     * @throws ParamNotSetException if a {@link ParamExpression} has no value in {@code params}
     */
    protected void setParameters(PreparedStatement stmt, List<?> objects,
            List<Path<?>> constantPaths, Map<ParamExpression<?>, ?> params) {
        if (objects.size() != constantPaths.size()) {
            throw new IllegalArgumentException("Expected " + objects.size() + " paths, " +
                    "but got " + constantPaths.size());
        }
        for (int i = 0; i < objects.size(); i++) {
            Object o = objects.get(i);
            try {
                if (o instanceof ParamExpression) {
                    if (!params.containsKey(o)) {
                        throw new ParamNotSetException((ParamExpression<?>) o);
                    }
                    o = params.get(o);
                }
                // JDBC parameter indices are 1-based
                configuration.set(stmt, constantPaths.get(i), i + 1, o);
            } catch (SQLException e) {
                throw configuration.translate(e);
            }
        }
    }

    /**
     * Execute a single statement and return its update count.
     *
     * <p>In literal mode the statement is run with {@code executeUpdate()}. Otherwise the batch
     * is executed, and the count is taken from {@code getUpdateCount()} or summed from the
     * per-command counts, depending on the templates.</p>
     *
     * @param stmt statement to execute
     * @return number of affected rows as reported by the driver
     * @throws SQLException if execution fails
     */
    private long executeBatch(PreparedStatement stmt) throws SQLException {
        // NOTE(review): this reads the configuration-level flag, while serializers use the
        // clause-level useLiterals field (see setUseLiterals). The two can diverge; confirm
        // subclasses decide addBatch() with the same flag before switching this to the field.
        if (configuration.getUseLiterals()) {
            return stmt.executeUpdate();
        } else if (configuration.getTemplates().isBatchCountViaGetUpdateCount()) {
            stmt.executeBatch();
            return stmt.getUpdateCount();
        } else {
            // NOTE(review): drivers may report Statement.SUCCESS_NO_INFO (-2) per command,
            // which this sum does not special-case
            long rv = 0;
            for (int i : stmt.executeBatch()) {
                rv += i;
            }
            return rv;
        }
    }

    /**
     * Execute all given statements and return the total update count.
     *
     * @param stmts statements to execute, in iteration order
     * @return sum of the update counts of all statements
     * @throws SQLException if any execution fails
     */
    protected long executeBatch(Collection<PreparedStatement> stmts) throws SQLException {
        long rv = 0;
        for (PreparedStatement stmt : stmts) {
            rv += executeBatch(stmt);
        }
        return rv;
    }

    /**
     * Close the given statement.
     *
     * @param stmt statement to close
     * @throws RuntimeException the translated {@link SQLException} if closing fails
     */
    protected void close(Statement stmt) {
        try {
            stmt.close();
        } catch (SQLException e) {
            throw configuration.translate(e);
        }
    }

    /**
     * Close all given statements.
     *
     * <p>Every statement is attempted even if an earlier one fails to close, so a single failure
     * does not leak the remaining statements. The first failure is rethrown after all statements
     * have been processed.</p>
     *
     * @param stmts statements to close
     * @throws RuntimeException the first failure raised by {@link #close(Statement)}
     */
    protected void close(Collection<? extends Statement> stmts) {
        RuntimeException firstFailure = null;
        for (Statement stmt : stmts) {
            try {
                close(stmt);
            } catch (RuntimeException e) {
                // keep closing the rest; remember only the first failure
                if (firstFailure == null) {
                    firstFailure = e;
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Close the given result set.
     *
     * @param rs result set to close
     * @throws RuntimeException the translated {@link SQLException} if closing fails
     */
    protected void close(ResultSet rs) {
        try {
            rs.close();
        } catch (SQLException e) {
            throw configuration.translate(e);
        }
    }

    /**
     * Log the query at debug level and publish it, together with its parameters, to the MDC.
     *
     * <p>Newlines are replaced with spaces. The MDC entries remain set until
     * {@link #cleanupMDC()} is called.</p>
     *
     * @param logger      logger to use
     * @param queryString SQL string to log
     * @param parameters  bind parameters to record in the MDC
     */
    protected void logQuery(Logger logger, String queryString, Collection<Object> parameters) {
        if (logger.isDebugEnabled()) {
            String normalizedQuery = queryString.replace('\n', ' ');
            MDC.put(QueryBase.MDC_QUERY, normalizedQuery);
            MDC.put(QueryBase.MDC_PARAMETERS, String.valueOf(parameters));
            logger.debug(normalizedQuery);
        }
    }

    /**
     * Remove the query and parameter entries placed into the MDC by
     * {@link #logQuery(Logger, String, Collection)}.
     */
    protected void cleanupMDC() {
        MDC.remove(QueryBase.MDC_QUERY);
        MDC.remove(QueryBase.MDC_PARAMETERS);
    }

    /**
     * Reset per-execution state; currently this only clears the MDC entries.
     */
    protected void reset() {
        cleanupMDC();
    }

    /**
     * Get the connection to execute against.
     *
     * <p>If none was supplied directly, it is obtained once from the provider and cached.</p>
     *
     * @return connection
     * @throws IllegalStateException if neither a connection nor a provider was given
     */
    protected Connection connection() {
        if (conn == null) {
            if (connProvider != null) {
                conn = connProvider.get();
            } else {
                throw new IllegalStateException("No connection provided");
            }
        }
        return conn;
    }

    /**
     * Set whether serializers created by this clause inline constants as literals instead of
     * using bind parameters.
     *
     * @param useLiterals true to serialize constants as literals
     */
    public void setUseLiterals(boolean useLiterals) {
        this.useLiterals = useLiterals;
    }

    /**
     * Get the batch count as reported by the concrete clause.
     *
     * <p>Presumably the number of batch entries added so far; confirm against subclasses.</p>
     *
     * @return batch count
     */
    public abstract int getBatchCount();

}
